package mgoDB

import (
	"context"
	"errors"
	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
	"time"
)

// MgoRepository is a MongoDB repository. Its methods operate on the
// `collection` reached through the embedded *CollectionContent (the only
// embedded field, so that is where m.collection is promoted from).
type MgoRepository struct {
	*CollectionContent
	timeout time.Duration // per-operation timeout handed to context.WithTimeout; it is a time.Duration, so set it with units (e.g. 5 * time.Second), not as a bare number of seconds
}

// Add inserts a single document.
//
// entity is the document to insert (a struct, per the original contract).
// On success it returns the InsertedID reported by the driver.
func (m *MgoRepository) Add(entity interface{}) (id interface{}, err error) {
	ctx, cancel := context.WithTimeout(context.Background(), m.timeout)
	defer cancel()

	res, insertErr := m.collection.InsertOne(ctx, entity)
	if insertErr != nil {
		return nil, insertErr
	}
	id = res.InsertedID
	return id, nil
}

// AddMany inserts several documents with a single InsertMany call.
//
// entity holds the documents to insert (structs, per the original contract).
// On success it returns the InsertedIDs reported by the driver; on failure
// the ids are nil and the driver error is returned.
func (m *MgoRepository) AddMany(entity ...interface{}) (ids []interface{}, err error) {
	ctx, cancel := context.WithTimeout(context.Background(), m.timeout)
	defer cancel()

	res, err := m.collection.InsertMany(ctx, entity)
	if err == nil {
		ids = res.InsertedIDs
	}
	return ids, err
}

// remove deletes at most one document whose _id equals id.
//
// It returns the DeletedCount reported by the driver, or 0 and the error
// if the delete fails.
func (m *MgoRepository) remove(id interface{}) (count int64, err error) {
	ctx, cancel := context.WithTimeout(context.Background(), m.timeout)
	defer cancel()

	filter := bson.M{"_id": id}
	res, delErr := m.collection.DeleteOne(ctx, filter)
	if delErr != nil {
		return 0, delErr
	}
	return res.DeletedCount, nil
}

// Remove deletes every document whose _id is one of ids.
//
// ids lists the primary keys to delete. Calling Remove with no ids is a
// no-op that returns (0, nil): a nil ids slice would otherwise be encoded by
// the driver as BSON null, and the server rejects {"$in": null}.
// It returns the DeletedCount reported by the driver.
func (m *MgoRepository) Remove(ids ...interface{}) (count int64, err error) {
	if len(ids) == 0 {
		return 0, nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), m.timeout)
	defer cancel()

	dResult, err := m.collection.DeleteMany(ctx, bson.M{"_id": bson.M{"$in": ids}})
	if err != nil {
		return 0, err
	}
	return dResult.DeletedCount, nil
}

// Set applies a $set (partial) update to a single document whose filterKey
// field equals filterValue.
//
// filterKey   name of the field to filter on
// filterValue value that field must equal
// entity      fields to set (partial updates are supported)
//
// It returns the driver's ModifiedCount. NOTE: this is the modified count,
// not the matched count, so a match whose fields already hold the new values
// reports 0.
func (m *MgoRepository) Set(filterKey string, filterValue interface{}, entity interface{}) (count int64, err error) {
	ctx, cancel := context.WithTimeout(context.Background(), m.timeout)
	defer cancel()

	filter := bson.M{filterKey: filterValue}
	res, updErr := m.collection.UpdateOne(ctx, filter, bson.M{"$set": entity})
	if updErr != nil {
		return 0, updErr
	}
	return res.ModifiedCount, nil
}

// FindById decodes the document whose _id equals id into result.
//
// id          primary key value to match
// filterField fields to project; nil means no projection (all fields)
// result      pointer that receives the decoded document
func (m *MgoRepository) FindById(id interface{}, filterField []string, result interface{}) (err error) {
	const idField = "_id"
	err = m.Find(idField, id, filterField, result)
	return err
}

// FindByIds decodes every document whose _id is one of ids into result.
//
// ids         primary keys to look up; nil is treated as an empty list
// filterField fields to project; nil means no projection (all fields)
// result      pointer to a slice that receives the decoded documents
func (m *MgoRepository) FindByIds(ids []interface{}, filterField []string, result interface{}) (err error) {
	if ids == nil {
		// The driver encodes a nil slice as BSON null, which the server
		// rejects as an $in operand; an empty array simply matches nothing.
		ids = []interface{}{}
	}
	filterMap := bson.M{
		"_id": bson.M{
			"$in": ids,
		},
	}
	return m.FindMany(filterMap, filterField, result)
}

// Find decodes a single document whose filterKey field equals filterValue
// into result.
//
// filterKey   name of the field to filter on
// filterValue value that field must equal
// filterField fields to project; nil means no projection (all fields)
// result      pointer that receives the decoded document
//
// Any error from the query or from Decode is returned unchanged.
func (m *MgoRepository) Find(filterKey string, filterValue interface{}, filterField []string, result interface{}) (err error) {
	ctx, cancel := context.WithTimeout(context.Background(), m.timeout)
	defer cancel()

	var findOpts []*options.FindOneOptions
	if filterField != nil {
		fields := make(map[string]int, len(filterField))
		for _, name := range filterField {
			fields[name] = 1 // 1 = include this field in the projection
		}
		findOpts = append(findOpts, &options.FindOneOptions{Projection: fields})
	}

	err = m.collection.FindOne(ctx, bson.M{filterKey: filterValue}, findOpts...).Decode(result)
	return err
}

// FindMany decodes every document matching filterMap into result.
//
// filterMap   query filter conditions
// filterField fields to project; nil means no projection (all fields)
// result      pointer to a slice that receives the decoded documents
//
// The repository timeout bounds the whole operation: the initial query and
// every batch fetched while draining the cursor.
func (m *MgoRepository) FindMany(filterMap map[string]interface{}, filterField []string, result interface{}) error {
	ctx, cancel := context.WithTimeout(context.Background(), m.timeout)
	defer cancel()

	var (
		cur *mongo.Cursor
		err error
	)
	if filterField != nil {
		projection := make(map[string]int, len(filterField))
		for _, val := range filterField {
			projection[val] = 1
		}
		cur, err = m.collection.Find(ctx, filterMap, &options.FindOptions{Projection: projection})
	} else {
		cur, err = m.collection.Find(ctx, filterMap)
	}
	if err != nil {
		return err
	}
	// Close with a fresh context so the cursor can still be released even if
	// ctx has already expired.
	defer cur.Close(context.Background())

	if err = cur.Err(); err != nil {
		return err
	}
	// Decode with ctx rather than context.Background(): the batches fetched
	// while draining the cursor must respect the same timeout as the query.
	return cur.All(ctx, result)
}

// FindCount returns the number of documents matching filterMap, as reported
// by CountDocuments, together with any error it produced.
//
// filterMap query filter conditions
func (m *MgoRepository) FindCount(filterMap map[string]interface{}) (count int64, err error) {
	ctx, cancel := context.WithTimeout(context.Background(), m.timeout)
	defer cancel()

	return m.collection.CountDocuments(ctx, filterMap)
}

// FindPage decodes one page of documents matching filterMap, sorted on a
// single field, into result, and optionally counts all matching documents.
//
// filterMap   query conditions, for example:
//
//	fuzzy match:        map["name"] = primitive.Regex{Pattern: "keyword"}
//	createtime >= 2:    bson.M{"createtime": bson.M{"$gte": 2}}
//	nested field:       map["author.country"] = countryChina
//
// pageIndex   1-based page number (must be >= 1)
// size        number of documents per page (must be > 0)
// sortName    field to sort on
// desc        sort direction: 1 ascending, -1 descending (when sorting on a
//
//	timestamp, 1 yields oldest first and -1 newest first)
//
// filterField fields to project; nil means all fields, e.g. []string{"field1", "field2"}
// result      pointer to a slice that receives the page
// isTotal     when true, also count every document matching filterMap
//
// It returns the total count when isTotal is true, otherwise 0. Errors from
// the query, from decoding, or from the count are all returned to the caller.
func (m *MgoRepository) FindPage(filterMap map[string]interface{}, pageIndex, size int64, sortName string, desc int, filterField []string, result interface{}, isTotal bool) (totalCount int64, err error) {
	if desc != 1 && desc != -1 {
		return 0, errors.New(" desc 必须为1或-1 ！")
	}
	if pageIndex < 1 {
		return 0, errors.New(" pageIndex 必须大于或等于 1！")
	}
	if size <= 0 {
		return 0, errors.New(" size 必须大于 0！")
	}

	ctx, cancel := context.WithTimeout(context.Background(), m.timeout)
	defer cancel()

	skip := (pageIndex - 1) * size
	findOptions := options.Find().
		SetSort(bson.D{{Key: sortName, Value: desc}}).
		SetSkip(skip).
		SetLimit(size)
	if filterField != nil {
		projection := make(map[string]int, len(filterField))
		for _, val := range filterField {
			projection[val] = 1
		}
		findOptions.SetProjection(projection)
	}

	// Fetch the page. The Find error must be checked before touching cur:
	// on failure cur is nil and calling a method on it would panic.
	cur, err := m.collection.Find(ctx, filterMap, findOptions)
	if err != nil {
		return 0, err
	}
	// Close with a fresh context so the cursor can still be released even if
	// ctx has already expired.
	defer cur.Close(context.Background())

	if err = cur.Err(); err != nil {
		return 0, err
	}
	// Decode with ctx so draining the cursor respects the repository timeout.
	if err = cur.All(ctx, result); err != nil {
		return 0, err
	}

	if !isTotal {
		return 0, nil
	}

	// The count gets its own timeout budget, separate from the page query.
	ctxCount, cancelCount := context.WithTimeout(context.Background(), m.timeout)
	defer cancelCount()

	totalCount, err = m.collection.CountDocuments(ctxCount, filterMap)
	if err != nil {
		return 0, err
	}
	return totalCount, nil
}
